#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 25 21:35:12 2022

@author: matthewroberts
"""

import os
import sys
import numpy as np
import glob

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.gridspec as gridspec
import matplotlib.colors as colors
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

import cartopy.crs as ccrs
import pyproj

import datetime as dt
import wrf
from netCDF4 import Dataset
import pandas as pd
from pyproj import Geod
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
from cartopy.io.img_tiles import GoogleTiles
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

import warnings
# Silence every warning for the whole run.
# NOTE(review): this also hides library deprecation warnings, which would
# otherwise flag upcoming API removals early.
warnings.filterwarnings('ignore')

# Wall-clock start; printed in the banner and reused for the elapsed-time summary at the end
start_time = dt.datetime.now()
print(f'\nProgram Started: {start_time}')
print('===================')

def relax_zone_remover(input, sr):
    """Trim the trailing relaxation-zone rows and columns from a fire-grid field.

    Drops the last ``sr`` entries along axis 0 and along axis 1, removing the
    extra fire-mesh points before the field is analysed/plotted.

    Parameters
    ----------
    input : array-like supporting 2-axis slicing (NumPy array or xarray.DataArray)
        Field to trim. It is not modified.
    sr : int
        Number of trailing rows/columns to drop. Values <= 0 are a no-op and
        return ``input`` itself.

    Returns
    -------
    A new array, independent of ``input``, whose first two axes are each
    ``sr`` shorter.
    """
    if sr <= 0:
        return input
    # A single slice + copy replaces 2*sr successive np.delete calls, each of
    # which reallocated the whole array. For an xarray.DataArray, positional
    # slicing also trims the coordinates together with the data.
    return input[:-sr, :-sr].copy()

#########################
# Inputs
#########################

# Alternate case (Bear Fire):
# firename = 'bear_fire' # for naming directories/figures
# extent = [-121.5,-120.9, 39.48,39.9] #bear fire
# idx0 = 0
# idx1 = -1

firename = 'caldor_fire'  # for naming directories/figures
extent = [-120.61, -120.28, 38.52, 38.76]  # [lon_min, lon_max, lat_min, lat_max], caldor fire
idx0 = 0   # which sorted _NOFIRE file supplies the fuel map
idx1 = -2  # which sorted fire-run file supplies the final perimeter

keyword = '_FUELx8'  # run subdirectory

mainpath = '/Users/matthewroberts/Documents/Projects/LEAPHI'
filepath = os.path.join(mainpath, 'data', firename, keyword, '')  # data files location
savepath = os.path.join(mainpath, 'plots', firename, '')  # where to save figure
filelist = sorted(glob.glob(os.path.join(filepath, 'wrfout_d03*')))

# dont plot initial perimeter: the fuel map is taken from the _NOFIRE run
fuelpath = sorted(glob.glob(os.path.join(mainpath, 'data', firename, '_NOFIRE', 'wrfout_d03*')))


def _select_file(files, index, description):
    """Return ``files[index]``, exiting with a readable message if it does not exist.

    Without this, an empty or too-short glob result surfaces as a bare
    IndexError that does not say which directory was searched.
    """
    try:
        return files[index]
    except IndexError:
        sys.exit('Cannot pick file index {0} from {1}: found {2} wrfout_d03* file(s).'.format(
            index, description, len(files)))


# get last file and perimeter
wrf_file = Dataset(_select_file(fuelpath, idx0, '_NOFIRE run'), mode='r')  # for initial fuel map
wrf_file2 = Dataset(_select_file(filelist, idx1, filepath), mode='r')  # for final perimeter

# Timestamp string built from the first model time in the fuel-map file
time = wrf.extract_times(wrf_file, wrf.ALL_TIMES)
tstamp = str(pd.Timestamp(time[0]))
print('\nOpening ' + tstamp + '...')

# Met-grid terrain height and lat/lon (last time index)
terrain, lat, lon = (wrf.getvar(wrf_file, varname, timeidx=-1)
                     for varname in ("HGT", "XLAT", "XLONG"))

# Fire-grid fields: LFN from the perimeter file, fuel categories and
# fire-mesh lat/lon from the fuel-map file
lfn1 = wrf.getvar(wrf_file2, 'LFN', timeidx=-1)
fuel = wrf.getvar(wrf_file, "NFUEL_CAT", timeidx=-1)
yf1, xf1 = (wrf.getvar(wrf_file, varname, timeidx=-1)
            for varname in ('FXLAT', 'FXLONG'))

# Subgrid ratio: fire-mesh west-east size over the staggered met-grid size
dims = wrf_file.dimensions
sr = int(dims['west_east_subgrid'].size) / int(dims['west_east_stag'].size)
n_trim = int(sr)

# Strip the trailing relaxation-zone rows/columns from every fire-grid field
lfn1, fuel, xf1, yf1 = (relax_zone_remover(field, n_trim)
                        for field in (lfn1, fuel, xf1, yf1))
# To match the fuel map to the met grid instead:
# fuel = fuel[::int(sr),::int(sr)]
# # match lfn to met grid
# fuel = fuel[::int(sr),::int(sr)]

# Fuel codes (NFUEL_CAT values) and their category labels, paired by position
fuelVals = [14,91,92,93,98,99,101,102,103,104,105,106,107,108,109,121,
            122,123,124,141,142,143,144,145,146,147,148,149,161,162,
            163,164,165,181,182,183,184,185,186,187,188,189]
fuelCats = ['NF','NB1','NB2','NB3','NB8','NB9','GR1','GR2','GR3','GR4',
            'GR5','GR6','GR7','GR8','GR9','GS1','GS2','GS3','GS4',
            'SH1','SH2','SH3','SH4','SH5','SH6','SH7','SH8','SH9',
            'TU1','TU2','TU3','TU4','TU5','TL1','TL2','TL3','TL4',
            'TL5','TL6','TL7','TL8','TL9']
fuel_code_names = dict(zip(fuelVals, fuelCats))

# Sorted distinct fuel codes on the fire grid, plus, for every cell, the
# position of its code in that sorted list. That position is the new
# 0-based fuel index used for plotting (one discrete colour per category).
# A single np.unique pass replaces the per-code np.where loop.
_unique_codes, _code_positions = np.unique(np.asarray(fuel), return_inverse=True)
fuel_vals = _unique_codes.astype(int)  # old fuel codes, for labeling
fuel_idx = np.arange(fuel_vals.size)   # new fuel codes: 0 .. n-1
# reshape: the inverse is flattened on some NumPy versions
fuelcopy = _code_positions.reshape(np.shape(fuel))

# Exactly one label per plotted index, so colorbar ticks and labels line up
# even when the grid holds a code missing from the table (shown as its number).
fuelkey = np.array([fuel_code_names.get(code, str(code)) for code in fuel_vals])

# TU5 share of non-NF cells, over the whole domain and over the fire area
NF_CODE = 14    # 'NF' in the fuel table
TU5_CODE = 165  # 'TU5' in the fuel table


def _tu5_fraction(codes):
    """Return the fraction of non-NF cells in *codes* whose fuel code is TU5.

    Returns NaN when *codes* contains no non-NF cells, instead of raising
    KeyError (absent code) or dividing by zero.
    """
    codes = np.asarray(codes)
    burnable = np.count_nonzero(codes != NF_CODE)
    if burnable == 0:
        return float('nan')
    return np.count_nonzero(codes == TU5_CODE) / burnable


def _pct_label(fraction):
    """Format a fraction as a truncated whole percent (0.427 -> '42%'), or 'n/a' for NaN."""
    return 'n/a' if np.isnan(fraction) else str(int(fraction * 100)) + '%'


firefuel = np.asarray(fuel)

# domain totals
pct1 = _tu5_fraction(firefuel)

# fire area totals: cells whose LFN is not > 0. Select them with a mask
# rather than overwriting the other cells with 0, so `fuel` is left untouched
# and a genuine fuel code 0 cannot be miscounted as "outside the fire".
in_fire = ~(np.asarray(lfn1) > 0)
pct2 = _tu5_fraction(firefuel[in_fire])

print('Domain TU5:  ' + _pct_label(pct1))
print('Fire TU5:  ' + _pct_label(pct2))


#%%
#########################
# Plotting
#########################
print('Plotting...', end="", flush=True)
fig = plt.figure(1, figsize=(9, 9))
ax1 = fig.add_subplot(111, projection=ccrs.PlateCarree())  # single map panel
ax1.set_extent(extent, crs=ccrs.PlateCarree())

# Draw the fuel map with one discrete colour per remapped fuel index.
# plt.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, which was
# deprecated in Matplotlib 3.7 and removed in 3.9.
cmap1 = plt.get_cmap('terrain', len(fuel_idx))
cmap1.set_under('k')
ticks = fuel_idx

# Optional masking experiments:
# fuelcopy[lfn1 > 0] = np.nan
# fuelcopy[fuel.data == 165] = 999
# val = 19
# fuelcopy = np.ma.masked_less(fuelcopy, val)
# fuelcopy = np.ma.masked_greater(fuelcopy, val)

# vmin/vmax sit half a step outside the index range so every integer index
# falls in the middle of its own colour bin.
fuelplot = ax1.pcolormesh(xf1, yf1, fuelcopy, cmap=cmap1,
                          vmin=fuel_idx.min() - .5, vmax=fuel_idx.max() + .5)

# Fire perimeter = the LFN zero contour. levels=[0] is explicit: a bare
# integer (previously `0`) asks Matplotlib to auto-pick up to one "nice"
# level, which is not guaranteed to be 0.
perim = ax1.contour(xf1, yf1, lfn1, levels=[0], linewidths=3, colors='k')
# tu5 = ax1.contour(xf1, yf1, fuelcopy, 165, linewidths=1, colors='k')

# Lat/lon tick locations and degree-formatted labels
ax1.set_xticks(np.linspace(extent[0], extent[1], 7), crs=ccrs.PlateCarree())
ax1.set_yticks(np.linspace(extent[2], extent[3], 8)[1:], crs=ccrs.PlateCarree())
lon_formatter = LongitudeFormatter(number_format='0.2f', degree_symbol=u'\N{DEGREE SIGN}', dateline_direction_label=True)
lat_formatter = LatitudeFormatter(number_format='0.2f', degree_symbol=u'\N{DEGREE SIGN}')
ax1.xaxis.set_major_formatter(lon_formatter)
ax1.yaxis.set_major_formatter(lat_formatter)
ax1.xaxis.set_tick_params(width=2, length=5, direction='in', labelsize=12, zorder=12, rotation=20, bottom=True, top=True)
ax1.yaxis.set_tick_params(width=2, length=5, direction='in', labelsize=12, zorder=12, left=True, right=True)

# County boundaries from a local shapefile
reader = shpreader.Reader('/Users/matthewroberts/Documents/Data/shapefiles/countyl010g_shp/countyl010g.shp')
counties = list(reader.geometries())
COUNTIES = cfeature.ShapelyFeature(counties, ccrs.PlateCarree())
ax1.add_feature(COUNTIES, facecolor='none', edgecolor='k', linewidth=1, alpha=.4, zorder=10)

# Thick map border drawn on top. Newer cartopy exposes the border as
# spines['geo']; GeoAxes.outline_patch was deprecated and later removed (0.20).
try:
    map_border = ax1.spines['geo']
except KeyError:
    map_border = ax1.outline_patch  # older cartopy without a 'geo' spine
map_border.set_linewidth(2)
map_border.set_zorder(99)

ax1.set_xlabel("Longitude", fontsize=12)
ax1.set_ylabel("Latitude", fontsize=12)

# Colorbar ticks are the 0..n-1 fuel indices, labeled with fuel category names
cbar = plt.colorbar(fuelplot, ax=ax1, shrink=.55, ticks=ticks)
cbar.set_ticklabels(fuelkey)

# plt.savefig(savepath+'/fuelmap.png',bbox_inches='tight',dpi=900)

plt.show()

sys.exit()





# NOTE(review): everything from here down to the second sys.exit() is
# leftover cross-section plotting code. It is unreachable because the script
# exits at the sys.exit() after the fuel-map plt.show() above. It also
# references names that are never defined in this file (x_array,
# interp_height, regrid_fact, regrid_fact2, comp_wind, w_cross, ter_cross,
# lfn_x, lfn_y, lat_cross, coord_pairs, spec, _lon_cross, _lat_cross), so it
# would raise NameError if it were reached.

# vectors
wf = ax1.quiver(x_array[::regrid_fact], interp_height[::regrid_fact2],
               comp_wind[::regrid_fact2,::regrid_fact],
               w_cross[::regrid_fact2,::regrid_fact],
               color='k', scale=500, width=0.003, pivot='tail', alpha=0.9, zorder=10)

# key
qk = ax1.quiverkey(wf, .96, 1.05, 20, r'$20 \frac{m}{s}$', labelpos='E',
                    coordinates='axes',fontproperties=dict(size=16))

# plot terrain
ax1.plot(x_array,ter_cross,color='k',linewidth=4.,zorder=10)
ax1.fill_between(x_array,ter_cross,y2=0, color='sienna', zorder=10)
# Plot fire area
ax1.plot(lfn_x,lfn_y,color='k',linewidth=7.,zorder=10)
ax1.plot(lfn_x,lfn_y,color='orangered',linewidth=4.,zorder=11)

# Set the x-ticks to use latitude and longitude labels.
tick_space = 90 #higher values = fewer ticks
x_ticks = np.arange(len(lat_cross))
x_labels = ['{:.2f}, {:.2f}'.format(lon,lat) for lon,lat in coord_pairs]
ax1.set_xticks(x_ticks[tick_space::tick_space],x_labels[tick_space::tick_space])#fontsize=10, rotation=20, horizontalalignment='right'
# Set x/y lims
ax1.set_xlim(x_array[0],x_array[-1])
ax1.set_ylim(0,15000)

# Set border and ticks around plot
ax1.spines['top'].set_zorder(15)
ax1.spines['bottom'].set_zorder(15)
ax1.spines['left'].set_zorder(15)
ax1.spines['right'].set_zorder(15)
# Open question from the author: the ticks do not draw on top of the data,
# even with zorder passed to tick_params below.
# NOTE(review): likely cause: tick marks are drawn as part of the Axis
# artists (ax1.xaxis / ax1.yaxis), and Matplotlib stacks those by the Axis
# zorder (set via Axes.set_axisbelow, below 10 by default), not by the
# tick_params zorder. Something like ax1.xaxis.set_zorder(15) /
# ax1.yaxis.set_zorder(15) may be needed. Untested here; confirm.
ax1.tick_params(axis="y",direction="in",zorder=15)
ax1.tick_params(axis="x",direction="in",zorder=15)
# Title and axis labels
ax1.set_title(tstamp+' UTC', loc='left', fontsize=14, fontweight='bold',zorder=10)
ax1.set_ylabel("Height [m]", fontsize=11)

# Inset map
ax2 = fig.add_subplot(spec[0:13,21:31], projection=ccrs.PlateCarree()) # spec[rows,cols]
# extent = [-121.5, -120.85, 39.44, 39.91]
# ax2.set_extent(extent, crs=ccrs.PlateCarree())

# terrain in inset
terMap = ax2.contourf(lon,lat,terrain, cmap='terrain', extend='both',
                      levels=np.arange(-400,2400,100), zorder=7,
                      transform=ccrs.PlateCarree())
# fire perimeter in inset
# NOTE(review): contour() documents the line-width keyword as `linewidths`
# (plural); `linewidth=3` here may be ignored. Confirm before reviving.
perimPlot = ax2.contour(xf1, yf1, lfn1, levels=[0], colors='k', linewidth=3,
                        zorder=8, transform=ccrs.PlateCarree())
# plot extent of x-sections (dashed lines on inset): first and last rows of the
# cross-section coordinate arrays, drawn as a white halo under a black dashed line
for jjj in [0,-1]:
    Xsect = ax2.plot(_lon_cross[jjj,:],_lat_cross[jjj,:],linewidth=3,color='white',
                     alpha=.5, zorder=10, transform=ccrs.PlateCarree())
    Xsect = ax2.plot(_lon_cross[jjj,:],_lat_cross[jjj,:],linewidth=1,color='k',
                     linestyle='--', zorder=10, transform=ccrs.PlateCarree())
ax2.set_xticks([])
ax2.set_yticks([])
# Set border around plot
ax2.spines['top'].set_zorder(15)
ax2.spines['bottom'].set_zorder(15)
ax2.spines['left'].set_zorder(15)
ax2.spines['right'].set_zorder(15)

plt.show()
sys.exit()

# Save the figure under a compact YYYYmmdd_HHMM timestamp, then report run time.
# NOTE: as written, the earlier sys.exit() calls stop the script before this
# point, so these steps do not currently run.
savetime = dt.datetime.strptime(tstamp, "%Y-%m-%d %H:%M:%S").strftime("%Y%m%d_%H%M")
outname = keyword + '_' + savetime + '.png'
plt.savefig(savepath + '/' + outname, bbox_inches='tight', dpi=120)
plt.close('all')
print('Saved ' + outname)


end_time = dt.datetime.now()
print('===================')
print('Total elapsed time: {0}\n'.format(end_time - start_time))